{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Basic Seq2Seq Practice\n",
    "### Source code comes from: '從零開始的 Sequence to Sequence' (Sequence to Sequence from Scratch)\n",
    "Article: http://zake7749.github.io/2017/09/28/Sequence-to-Sequence-tutorial/ <br>\n",
    "Github: https://github.com/zake7749/Sequence-to-Sequence-101"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
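    "## Seq2Seq Model\n",
    "![](./image/seq2seq.png)\n",
    "\n",
    "##### Seq2Seq Model: built from two sequential models; both the input and the output can be sequences. It is also called the Encoder-Decoder framework.\n",
    "1. Sequential models are good at data with sequential structure (text, speech, time series). The usual building blocks are RNN, LSTM, and GRU.\n",
    "2. Encoder: converts the input text into a context vector that the machine can work with.\n",
    "3. Decoder: converts the context vector back into text that we can read."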
   ]
  },
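  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The encoder half of the description above can be sketched in a few lines. This is a minimal illustration, not the original author's code: the sizes are hypothetical, the token ids are random, and it assumes a recent PyTorch (0.4 or later) where tensors no longer need to be wrapped in `Variable`.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "\n",
    "# Hypothetical sizes, chosen only for illustration.\n",
    "vocab_size, embedding_size, hidden_size = 30, 8, 16\n",
    "seq_len, batch_size = 6, 3\n",
    "\n",
    "embedding = nn.Embedding(vocab_size, embedding_size)\n",
    "encoder_gru = nn.GRU(embedding_size, hidden_size)\n",
    "\n",
    "# A batch of random token ids with shape (seq_len, batch_size).\n",
    "tokens = torch.randint(0, vocab_size, (seq_len, batch_size))\n",
    "\n",
    "outputs, hidden = encoder_gru(embedding(tokens))\n",
    "\n",
    "# The final hidden state plays the role of the context vector.\n",
    "context_vector = hidden\n",
    "print(outputs.shape)         # (seq_len, batch_size, hidden_size)\n",
    "print(context_vector.shape)  # (num_layers=1, batch_size, hidden_size)\n",
    "```\n",
    "\n",
    "For a single-layer, unidirectional GRU the context vector equals the output at the last time step, which is why one fixed-size vector can summarize the whole input."
   ]
  },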
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import torch.nn as nn\n",
    "from torch.autograd import Variable\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "\n",
    "# Custom modules written by the original author (see the author's GitHub repository)\n",
    "from dataset.DataHelper import DataTransformer\n",
    "from config import config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the Seq2Seq Model (combining encoder & decoder)\n",
    "1. `__init__`: takes the encoder and decoder models and stores them. <br>\n",
    "2. `forward`: unpacks the input data, runs the encoder forward pass, then runs the decoder forward pass.\n",
    "3. `evaluation`: takes test data, runs the encoder forward pass, then runs the decoder's evaluation method."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Seq2Seq(nn.Module):\n",
    "    def __init__(self, encoder, decoder):\n",
    "        super(Seq2Seq, self).__init__()\n",
    "        # Store the encoder and decoder models.\n",
    "        self.encoder = encoder\n",
    "        self.decoder = decoder\n",
    "        \n",
    "\n",
    "    def forward(self, inputs, targets):\n",
    "        # Unpack the training batch: padded input ids and their lengths\n",
    "        input_vars, input_lengths = inputs\n",
    "        \n",
    "        # Run the encoder (call the module rather than .forward so hooks still run)\n",
    "        encoder_outputs, encoder_hidden = self.encoder(input_vars, input_lengths)\n",
    "        \n",
    "        # Run the decoder, using the encoder's final hidden state as the context vector\n",
    "        decoder_outputs, decoder_hidden = self.decoder(context_vector=encoder_hidden, targets=targets)\n",
    "        return decoder_outputs, decoder_hidden\n",
    "\n",
    "    \n",
    "    \n",
    "    def evaluation(self, inputs):\n",
    "        # Unpack the test batch\n",
    "        input_vars, input_lengths = inputs\n",
    "        \n",
    "        # Encode the input, then decode greedily from the context vector\n",
    "        encoder_outputs, encoder_hidden = self.encoder(input_vars, input_lengths)\n",
    "        decoded_sentence = self.decoder.evaluation(context_vector=encoder_hidden)\n",
    "        return decoded_sentence"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Encoder Model\n",
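    "Converts an input sequence into vectors with an Embedding layer, and uses the RNN's hidden state at the last time step as the context vector. <br>\n",
    "\n",
    "* Note on the PackedSequence object:<br>\n",
    "1. In a recurrent neural network, inputs and outputs differ in length from one sample to the next, so they must be padded to a common length before they can be trained as a batch. PyTorch provides a class called PackedSequence that lets the RNN skip the padded positions.<br>\n",
    "http://www.cnblogs.com/lindaxin/p/8052043.html <br>\n",
    "![](./image/padd.png)\n",
    "2. Use torch.nn.utils.rnn.pack_padded_sequence to convert a padded Variable into a PackedSequence; to convert back to a padded Variable, use torch.nn.utils.rnn.pad_packed_sequence."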
   ]
  },
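  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "A small round trip makes the two functions above concrete. This sketch assumes a recent PyTorch (plain tensors instead of `Variable`) and uses a made-up batch of two sequences with lengths 3 and 2, already sorted from longest to shortest.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence\n",
    "\n",
    "# Padded batch of shape (T=3, B=2, feature=1); 0.0 marks padding.\n",
    "# Sequence 0 is [1, 2, 3]; sequence 1 is [4, 5] plus one padded step.\n",
    "padded = torch.tensor([[[1.0], [4.0]],\n",
    "                       [[2.0], [5.0]],\n",
    "                       [[3.0], [0.0]]])\n",
    "lengths = [3, 2]\n",
    "\n",
    "packed = pack_padded_sequence(padded, lengths)\n",
    "print(packed.data.view(-1).tolist())  # [1.0, 4.0, 2.0, 5.0, 3.0]: time-major, padding dropped\n",
    "print(packed.batch_sizes.tolist())    # [2, 2, 1]: active sequences at each time step\n",
    "\n",
    "# Convert back to a padded tensor plus the original lengths.\n",
    "unpacked, unpacked_lengths = pad_packed_sequence(packed)\n",
    "```\n",
    "\n",
    "The packed data stores one time step after another and simply leaves out the padded entry, so the RNN never computes on padding."
   ]
  },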
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VanillaEncoder(nn.Module):\n",
    "\n",
    "    def __init__(self, vocab_size, embedding_size, output_size):\n",
    "        \"\"\"Define layers for a vanilla rnn encoder\"\"\"\n",
    "        super(VanillaEncoder, self).__init__()\n",
    "\n",
    "        self.vocab_size = vocab_size\n",
    "        self.embedding = nn.Embedding(vocab_size, embedding_size)\n",
    "        self.gru = nn.GRU(embedding_size, output_size) # GRU: one kind of rnn model\n",
    "\n",
    "    def forward(self, input_seqs, input_lengths, hidden=None):\n",
    "        # Token ids -> embedding vectors, shape (T, B, embedding_size)\n",
    "        embedded = self.embedding(input_seqs)\n",
    "        # Padded batch -> PackedSequence so the GRU skips padded positions (lengths sorted in decreasing order)\n",
    "        packed = pack_padded_sequence(embedded, input_lengths)\n",
    "        # Run the GRU\n",
    "        packed_outputs, hidden = self.gru(packed, hidden)\n",
    "        # PackedSequence -> padded tensor\n",
    "        outputs, output_lengths = pad_packed_sequence(packed_outputs)\n",
    "        return outputs, hidden\n",
    "\n",
    "    def forward_a_sentence(self, inputs, hidden=None):\n",
    "        \"\"\"Deprecated, forward 'one' sentence at a time which is bad for gpu utilization\"\"\"\n",
    "        embedded = self.embedding(inputs)\n",
    "        outputs, hidden = self.gru(embedded, hidden)\n",
    "        return outputs, hidden"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Decoder Model\n",
    "\n",
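    "##### The Decoder is similar to the Encoder, except that besides receiving input from the Encoder, its output at each time step becomes its input at the next time step. Key points:\n",
    "\n",
    "* Flow1. first input: SOS (Start of sentence) <br> \n",
    "* Flow2. first hidden : Pass the context vector <br>\n",
    "* Flow3. The Decoder's output at each time step is used as the input at the next time step. [forward_step] runs the RNN; like the Encoder it is a GRU, but it also produces an output at every time step.<br>\n",
    "* Training tip: teacher_forcing_ratio is a constant probability (0.5 in this example) of feeding the ground-truth labels as the Decoder's next inputs instead of its own predictions, which helps stabilize training."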
   ]
  },
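  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The teacher forcing choice described above can be isolated in a small helper. This is a hedged sketch: `next_decoder_input` is a hypothetical function written for illustration, the scores are made up, and a recent PyTorch is assumed.\n",
    "\n",
    "```python\n",
    "import random\n",
    "import torch\n",
    "\n",
    "def next_decoder_input(step_log_probs, target_step, use_teacher_forcing):\n",
    "    # step_log_probs: (B, V) scores for the current step; target_step: (B,) gold ids.\n",
    "    if use_teacher_forcing:\n",
    "        return target_step.unsqueeze(0)               # (1, B) ground-truth ids\n",
    "    return step_log_probs.argmax(dim=1).unsqueeze(0)  # (1, B) greedy predictions\n",
    "\n",
    "random.seed(0)\n",
    "teacher_forcing_ratio = 0.5\n",
    "# Teacher forcing is switched on with probability teacher_forcing_ratio.\n",
    "use_teacher_forcing = random.random() < teacher_forcing_ratio\n",
    "\n",
    "scores = torch.tensor([[2.0, 0.1, 0.3],\n",
    "                       [0.2, 0.1, 3.0]])\n",
    "step_log_probs = torch.log_softmax(scores, dim=1)\n",
    "target_step = torch.tensor([1, 2])\n",
    "\n",
    "forced = next_decoder_input(step_log_probs, target_step, True)\n",
    "free = next_decoder_input(step_log_probs, target_step, False)\n",
    "print(forced.tolist(), free.tolist())  # [[1, 2]] [[0, 2]]\n",
    "```\n",
    "\n",
    "With teacher forcing the decoder sees the gold token even when its own prediction for the first sample (index 0) was wrong, which keeps early training from drifting on its own mistakes."
   ]
  },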
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "class VanillaDecoder(nn.Module):\n",
    "    def __init__(self, hidden_size, output_size, max_length, teacher_forcing_ratio, sos_id, use_cuda):\n",
    "        \"\"\"Define layers for a vanilla rnn decoder\"\"\"\n",
    "        super(VanillaDecoder, self).__init__()\n",
    "\n",
    "        self.hidden_size = hidden_size\n",
    "        self.output_size = output_size\n",
    "        self.embedding = nn.Embedding(output_size, hidden_size)\n",
    "        self.gru = nn.GRU(hidden_size, hidden_size)\n",
    "        self.out = nn.Linear(hidden_size, output_size)\n",
    "        self.log_softmax = nn.LogSoftmax(dim=1)  # LogSoftmax followed by NLLLoss is equivalent to CrossEntropyLoss\n",
    "\n",
    "        self.max_length = max_length\n",
    "        self.teacher_forcing_ratio = teacher_forcing_ratio\n",
    "        self.sos_id = sos_id\n",
    "        self.use_cuda = use_cuda\n",
    "        \n",
    "        \n",
    "    def forward_step(self, inputs, hidden):\n",
    "        '''Run the GRU for one time step.\n",
    "           Like the Encoder it is a GRU, but it also emits an output at every time step.'''\n",
    "        # inputs: (time_steps=1, batch_size)\n",
    "        batch_size = inputs.size(1)\n",
    "        embedded = self.embedding(inputs)\n",
    "        embedded = embedded.view(1, batch_size, self.hidden_size)  # S = T(1) x B x N; view() is not in-place, so keep the result\n",
    "        rnn_output, hidden = self.gru(embedded, hidden)  # S = T(1) x B x H\n",
    "        rnn_output = rnn_output.squeeze(0)  # squeeze the time dimension\n",
    "        output = self.log_softmax(self.out(rnn_output))  # S = B x O\n",
    "        # self.out: nn.Linear, i.e. Ax + b\n",
    "        # self.log_softmax: produces log-probabilities, to be paired with NLLLoss\n",
    "        return output, hidden\n",
    "    \n",
    "    \n",
    "    ### Main flow:\n",
    "    def forward(self, context_vector, targets):\n",
    "        # Prepare variable for decoder on time_step_0\n",
    "        target_vars, target_lengths = targets\n",
    "        batch_size = context_vector.size(1)\n",
    "        \n",
    "        # Flow1. first input: SOS (start of sentence)\n",
    "        decoder_input = Variable(torch.LongTensor([[self.sos_id] * batch_size]))\n",
    "        \n",
    "        \n",
    "        # Flow2. first hidden: the context vector that comes from the Encoder\n",
    "        decoder_hidden = context_vector\n",
    "\n",
    "        max_target_length = max(target_lengths)\n",
    "        decoder_outputs = Variable(torch.zeros(\n",
    "            max_target_length,\n",
    "            batch_size,\n",
    "            self.output_size\n",
    "        ))  # (time_steps, batch_size, vocab_size)\n",
    "\n",
    "        if self.use_cuda:\n",
    "            decoder_input = decoder_input.cuda()\n",
    "            decoder_outputs = decoder_outputs.cuda()\n",
    "\n",
    "            \n",
    "        # Training tip: teacher_forcing_ratio is a constant probability (0.5 in this example) of feeding\n",
    "        # the ground-truth labels as the Decoder's next inputs, which helps stabilize training.\n",
    "        # The choice is made once per batch.\n",
    "        use_teacher_forcing = random.random() < self.teacher_forcing_ratio\n",
    "        \n",
    "        \n",
    "        # Flow3. The Decoder's output at each time step becomes the input at the next time step.\n",
    "        # forward_step runs the RNN; like the Encoder it is a GRU, but it emits an output at every step.\n",
    "        for t in range(max_target_length):\n",
    "            decoder_outputs_on_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
    "            decoder_outputs[t] = decoder_outputs_on_t\n",
    "            \n",
    "            # Teacher forcing, as described above\n",
    "            if use_teacher_forcing:\n",
    "                # feed the ground-truth label for step t\n",
    "                decoder_input = target_vars[t].unsqueeze(0)\n",
    "            else:\n",
    "                # feed the model's own top-1 prediction for step t\n",
    "                decoder_input = self._decode_to_index(decoder_outputs_on_t)\n",
    "\n",
    "        return decoder_outputs, decoder_hidden\n",
    "    \n",
    "    \n",
    "    \n",
    "    \n",
    "\n",
    "    def evaluation(self, context_vector):\n",
    "        batch_size = context_vector.size(1) # get the batch size\n",
    "        decoder_input = Variable(torch.LongTensor([[self.sos_id] * batch_size]))\n",
    "        decoder_hidden = context_vector\n",
    "\n",
    "        decoder_outputs = Variable(torch.zeros(\n",
    "            self.max_length,\n",
    "            batch_size,\n",
    "            self.output_size\n",
    "        ))  # (time_steps, batch_size, vocab_size)\n",
    "\n",
    "        if self.use_cuda:\n",
    "            decoder_input = decoder_input.cuda()\n",
    "            decoder_outputs = decoder_outputs.cuda()\n",
    "\n",
    "        # Unfold the decoder RNN on the time dimension\n",
    "        for t in range(self.max_length):\n",
    "            decoder_outputs_on_t, decoder_hidden = self.forward_step(decoder_input, decoder_hidden)\n",
    "            decoder_outputs[t] = decoder_outputs_on_t\n",
    "            decoder_input = self._decode_to_index(decoder_outputs_on_t)  # select the former output as input\n",
    "\n",
    "        return self._decode_to_indices(decoder_outputs)\n",
    "     \n",
    "    \n",
    "\n",
    "    def _decode_to_index(self, decoder_output):\n",
    "        \"\"\"\n",
    "        evaluate on the logits, get the index of top1:\n",
    "        param decoder_output: S = B x V or T x V\n",
    "        \"\"\"\n",
    "        value, index = torch.topk(decoder_output, 1)\n",
    "        # Returns the k largest elements of the given input tensor along a given dimension.\n",
    "        index = index.transpose(0, 1)  # S = 1 x B, 1 is the index of top1 class\n",
    "        if self.use_cuda:\n",
    "            index = index.cuda()\n",
    "        return index\n",
    "    \n",
    "    \n",
    "\n",
    "    def _decode_to_indices(self, decoder_outputs):\n",
    "        \"\"\"\n",
    "        Evaluate on the decoder outputs(logits), find the top 1 indices.\n",
    "        Please confirm that the model is on evaluation mode if dropout/batch_norm layers have been added\n",
    "        :param decoder_outputs: the output sequence from decoder, shape = T x B x V \n",
    "        \"\"\"\n",
    "        decoded_indices = []\n",
    "        batch_size = decoder_outputs.size(1)\n",
    "        decoder_outputs = decoder_outputs.transpose(0, 1)  # S = B x T x V\n",
    "\n",
    "        for b in range(batch_size):\n",
    "            top_ids = self._decode_to_index(decoder_outputs[b])\n",
    "            decoded_indices.append(top_ids.data[0])\n",
    "        return decoded_indices"
   ]
  },
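  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The decoder ends in `LogSoftmax` because the trainer pairs it with `NLLLoss`; together they compute the same value as `CrossEntropyLoss` applied to the raw scores. A quick check, using random scores and hypothetical sizes, and assuming a recent PyTorch:\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "torch.manual_seed(0)\n",
    "logits = torch.randn(4, 5)            # (batch, vocab), hypothetical sizes\n",
    "targets = torch.tensor([0, 3, 1, 4])  # gold class ids\n",
    "\n",
    "# Two-step route: log-probabilities, then negative log-likelihood.\n",
    "log_probs = nn.LogSoftmax(dim=1)(logits)\n",
    "nll = nn.NLLLoss()(log_probs, targets)\n",
    "\n",
    "# One-step route on the raw scores.\n",
    "ce = nn.CrossEntropyLoss()(logits, targets)\n",
    "\n",
    "print(torch.allclose(nll, ce))\n",
    "```\n",
    "\n",
    "Keeping the log-softmax inside the decoder means its outputs are already log-probabilities, which is also what the top-1 decoding in `_decode_to_index` ranks."
   ]
  },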
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Building the Training Object\n",
    "1. `__init__`: initializes the seq2seq model, the dataset information, and the optimizer settings\n",
    "2. `train`: trains the seq2seq model for [num_epochs] epochs, iterating over [mini_batches]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "class Trainer(object):\n",
    "\n",
    "    def __init__(self, model, data_transformer, learning_rate, use_cuda,\n",
    "                 checkpoint_name=config.checkpoint_name,\n",
    "                 teacher_forcing_ratio=config.teacher_forcing_ratio):\n",
    "\n",
    "        self.model = model #seq2seq model\n",
    "\n",
    "        # record some information about dataset\n",
    "        self.data_transformer = data_transformer\n",
    "        self.vocab_size = self.data_transformer.vocab_size\n",
    "        self.PAD_ID = self.data_transformer.PAD_ID\n",
    "        self.use_cuda = use_cuda\n",
    "\n",
    "        # optimizer setting\n",
    "        self.learning_rate = learning_rate\n",
    "        self.optimizer= torch.optim.Adam(self.model.parameters(), lr=learning_rate)\n",
    "        self.criterion = torch.nn.NLLLoss(ignore_index=self.PAD_ID, size_average=True)\n",
    "\n",
    "        self.checkpoint_name = checkpoint_name\n",
    "        \n",
    "        \n",
    "\n",
    "    def train(self, num_epochs, batch_size, pretrained=False):\n",
    "\n",
    "        if pretrained:\n",
    "            self.load_model()\n",
    "\n",
    "        step = 0\n",
    "\n",
    "        for epoch in range(0, num_epochs):\n",
    "            mini_batches = self.data_transformer.mini_batches(batch_size=batch_size)\n",
    "            for input_batch, target_batch in mini_batches:\n",
    "                self.optimizer.zero_grad()\n",
    "                \n",
    "                # Forward pass through the seq2seq model\n",
    "                decoder_outputs, decoder_hidden = self.model(input_batch, target_batch)\n",
    "\n",
    "                # calculate the loss and back prop.\n",
    "                cur_loss = self.get_loss(decoder_outputs, target_batch[0])\n",
    "\n",
    "                # logging\n",
    "                step += 1\n",
    "                if step % 50 == 0:\n",
    "                    print(\"Step:\", step, \"loss of char: \", cur_loss.data[0])\n",
    "                    self.save_model()\n",
    "                cur_loss.backward()\n",
    "\n",
    "                # optimizing parameter\n",
    "                # torch.optim.Adam(self.model.parameters(), lr=learning_rate)\n",
    "                self.optimizer.step()\n",
    "        self.save_model()\n",
    "\n",
    "        \n",
    "    def masked_nllloss(self):\n",
    "        # Deprecated in PyTorch 2.0, can be replaced by ignore_index\n",
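    "        # Define a masked NLLLoss: a zero weight on PAD_ID removes padding from the loss\n",
    "        weight = torch.ones(self.vocab_size)\n",
    "        weight[self.PAD_ID] = 0\n",
    "        if self.use_cuda:\n",
    "            weight = weight.cuda()\n",
    "        return torch.nn.NLLLoss(weight=weight)  # weight is already on the right device; no unconditional .cuda()\n",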
    "    \n",
    "\n",
    "    def get_loss(self, decoder_outputs, targets):\n",
    "        b = decoder_outputs.size(1)\n",
    "        t = decoder_outputs.size(0)\n",
    "        targets = targets.contiguous().view(-1)  # S = (B*T)\n",
    "        decoder_outputs = decoder_outputs.view(b * t, -1)  # S = (B*T) x V\n",
    "        return self.criterion(decoder_outputs, targets)\n",
    "    \n",
    "\n",
    "    def save_model(self):\n",
    "        torch.save(self.model.state_dict(), self.checkpoint_name)\n",
    "        print(\"Model has been saved as %s.\\n\" % self.checkpoint_name)\n",
    "\n",
    "    def load_model(self):\n",
    "        self.model.load_state_dict(torch.load(self.checkpoint_name))\n",
    "        print(\"Pretrained model has been loaded.\\n\")\n",
    "\n",
    "    def tensorboard_log(self):\n",
    "        pass\n",
    "    \n",
    "    \n",
    "\n",
    "    def evaluate(self, words):\n",
    "        # make sure that words is a list\n",
    "        if not isinstance(words, list):\n",
    "            words = [words]\n",
    "\n",
    "        # transform word to index-sequence\n",
    "        eval_var = self.data_transformer.evaluation_batch(words=words)\n",
    "        decoded_indices = self.model.evaluation(eval_var)\n",
    "        results = []\n",
    "        for indices in decoded_indices:\n",
    "            results.append(self.data_transformer.vocab.indices_to_sequence(indices))\n",
    "        return results"
   ]
  },
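  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `ignore_index` argument used by the criterion above is what keeps padded positions out of the loss. The sketch below flattens a `(T, B, V)` output the same way `get_loss` does and compares the result with a manual average over non-padding positions. `PAD_ID = 2`, the sizes, and the random scores are assumptions for illustration; a recent PyTorch is assumed.\n",
    "\n",
    "```python\n",
    "import torch\n",
    "import torch.nn as nn\n",
    "\n",
    "PAD_ID = 2          # hypothetical padding id\n",
    "T, B, V = 3, 2, 4   # time steps, batch size, vocab size\n",
    "\n",
    "torch.manual_seed(0)\n",
    "log_probs = torch.log_softmax(torch.randn(T, B, V), dim=2)\n",
    "targets = torch.tensor([[0, 1],\n",
    "                        [3, PAD_ID],\n",
    "                        [PAD_ID, PAD_ID]])  # (T, B)\n",
    "\n",
    "flat_log_probs = log_probs.view(T * B, V)  # (T*B, V)\n",
    "flat_targets = targets.view(-1)            # (T*B,)\n",
    "\n",
    "criterion = nn.NLLLoss(ignore_index=PAD_ID)\n",
    "loss = criterion(flat_log_probs, flat_targets)\n",
    "\n",
    "# Manual check: average only over the non-padding positions.\n",
    "mask = flat_targets != PAD_ID\n",
    "rows = mask.nonzero().squeeze(1)\n",
    "manual = -flat_log_probs[rows, flat_targets[rows]].mean()\n",
    "print(torch.allclose(loss, manual))\n",
    "```\n",
    "\n",
    "Only three of the six positions carry a real token here, so the mean is taken over three terms rather than six."
   ]
  },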
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Function to Train the Model\n",
    "1. Declare the encoder model\n",
    "2. Declare the decoder model\n",
    "3. Declare the seq2seq model (from the encoder & decoder models)\n",
    "4. Declare the training object\n",
    "5. Run the training object"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 28,
   "metadata": {},
   "outputs": [],
   "source": [
    "def main():\n",
    "    # DataTransformer and config are custom modules by the original author (see the author's GitHub repository):\n",
    "    #   from dataset.DataHelper import DataTransformer\n",
    "    #   from config import config\n",
    "    use_cuda = False  # single device flag used everywhere below; set to config.use_cuda to train on GPU\n",
    "    data_transformer = DataTransformer(config.dataset_path, use_cuda=use_cuda)\n",
    "\n",
    "    # 1. Declare the encoder model\n",
    "    vanilla_encoder = VanillaEncoder(vocab_size=data_transformer.vocab_size,\n",
    "                                     embedding_size=config.encoder_embedding_size,\n",
    "                                     output_size=config.encoder_output_size)\n",
    "    # 2. Declare the decoder model\n",
    "    vanilla_decoder = VanillaDecoder(hidden_size=config.decoder_hidden_size,\n",
    "                                     output_size=data_transformer.vocab_size,\n",
    "                                     max_length=data_transformer.max_length,\n",
    "                                     teacher_forcing_ratio=config.teacher_forcing_ratio,\n",
    "                                     sos_id=data_transformer.SOS_ID,\n",
    "                                     use_cuda=use_cuda)\n",
    "    if use_cuda:\n",
    "        vanilla_encoder = vanilla_encoder.cuda()\n",
    "        vanilla_decoder = vanilla_decoder.cuda()\n",
    "        \n",
    "\n",
    "    # 3. Declare the seq2seq model (from the encoder & decoder models)\n",
    "    seq2seq = Seq2Seq(encoder=vanilla_encoder,\n",
    "                      decoder=vanilla_decoder)\n",
    "\n",
    "    # 4. Declare the training object\n",
    "    trainer = Trainer(seq2seq, data_transformer, config.learning_rate, use_cuda)\n",
    "    \n",
    "    # 5. Run the training object\n",
    "    trainer.train(num_epochs=config.num_epochs, batch_size=config.batch_size, pretrained=False)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Starting to run    \n",
    "if __name__ == \"__main__\":\n",
    "    main()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Reference:\n",
    "* 科技大擂台 Pytorch Seq2Seq 篇: \n",
    "https://fgc.stpi.narl.org.tw/activity/videoDetail/4b1141305df38a7c015e194f22f8015b\n",
    "\n",
    "* PyTorch Documentation: \n",
    "http://pytorch.org/tutorials/intermediate/seq2seq_translation_tutorial.htm\n",
    "\n",
    "###### ==> Further reading: combining Seq2Seq with Attention"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_batch with shape: torch.Size([14, 5])\n",
      "(Variable containing:\n",
      "    6    19     9    20     8\n",
      "   26     9    24     9    20\n",
      "    4    24     6    14     1\n",
      "   13    12    11     1     2\n",
      "    9    22     1     2     2\n",
      "    7     1     2     2     2\n",
      "   13     2     2     2     2\n",
      "   11     2     2     2     2\n",
      "   12     2     2     2     2\n",
      "   10     2     2     2     2\n",
      "    9     2     2     2     2\n",
      "   13     2     2     2     2\n",
      "   16     2     2     2     2\n",
      "    1     2     2     2     2\n",
      "[torch.LongTensor of size 14x5]\n",
      ", [14, 6, 5, 4, 3])\n"
     ]
    }
   ],
   "source": [
    "data_transformer = DataTransformer(config.dataset_path, use_cuda=False)\n",
    "text_data = data_transformer.mini_batches(batch_size=5)\n",
    "\n",
    "for input_batch, target_batch in text_data:\n",
    "    print('input_batch with shape:', input_batch[0].shape)\n",
    "    print(input_batch)\n",
    "    break"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 47,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "embedding_size:  8\n",
      "vocab_size    :  30\n"
     ]
    }
   ],
   "source": [
    "embedding_size= 8\n",
    "vocab_size=data_transformer.vocab_size\n",
    "EM = nn.Embedding(vocab_size, embedding_size)\n",
    "print('embedding_size: ', embedding_size)\n",
    "print('vocab_size    : ', vocab_size)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 54,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Variable containing:\n",
       "(0 ,.,.) = \n",
       " -2.8272 -0.8495 -0.7722 -1.6469 -1.4244 -1.0290  0.4223  2.0485\n",
       " -0.5904  0.5569  0.2864 -0.7210  1.0009 -0.9796 -0.5699  0.1710\n",
       "  1.2649 -1.5756 -0.1523  1.3332 -0.3693 -0.6410  2.5663  1.0958\n",
       " -0.7273 -0.1256 -1.8895  0.4864 -0.3066 -0.0811 -0.0353  0.0683\n",
       " -0.8637 -0.4147 -0.9887 -1.0099  0.7571 -0.8994  1.3255  0.9284\n",
       "\n",
       "(1 ,.,.) = \n",
       "  0.1070  0.8086  0.6861 -0.0120  0.4722  1.7426  1.9810  1.5739\n",
       "  1.2649 -1.5756 -0.1523  1.3332 -0.3693 -0.6410  2.5663  1.0958\n",
       " -0.7698 -1.1971 -0.4080 -1.6318  0.3038  1.4950  0.4344 -0.9362\n",
       "  1.2649 -1.5756 -0.1523  1.3332 -0.3693 -0.6410  2.5663  1.0958\n",
       " -0.7273 -0.1256 -1.8895  0.4864 -0.3066 -0.0811 -0.0353  0.0683\n",
       "\n",
       "(2 ,.,.) = \n",
       " -0.7275 -0.0046 -0.3437 -0.7898  1.7581 -0.9436 -0.2465  0.7758\n",
       " -0.7698 -1.1971 -0.4080 -1.6318  0.3038  1.4950  0.4344 -0.9362\n",
       " -2.8272 -0.8495 -0.7722 -1.6469 -1.4244 -1.0290  0.4223  2.0485\n",
       "  1.4335  1.1976 -0.1213  0.2545  0.6740  0.6193  0.5124  0.2051\n",
       " -1.4960  1.6445 -0.5717 -0.2748 -1.4087 -0.3116 -1.2721  1.0730\n",
       "\n",
       "(3 ,.,.) = \n",
       " -0.3116 -0.0417  0.0409 -0.8809  0.0966  1.9488  1.3507 -0.0887\n",
       "  0.0232 -0.9956 -0.1627  0.1138 -0.7270 -0.9151  0.8247  0.4661\n",
       " -0.4732  1.0304 -1.0002 -0.5305  0.8912  0.9431  2.6662  0.8570\n",
       " -1.4960  1.6445 -0.5717 -0.2748 -1.4087 -0.3116 -1.2721  1.0730\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(4 ,.,.) = \n",
       "  1.2649 -1.5756 -0.1523  1.3332 -0.3693 -0.6410  2.5663  1.0958\n",
       "  0.6948 -2.7214  1.0446  0.4498  1.5954 -0.1916 -0.0311 -0.9364\n",
       " -1.4960  1.6445 -0.5717 -0.2748 -1.4087 -0.3116 -1.2721  1.0730\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(5 ,.,.) = \n",
       "  0.3776  0.6509 -0.4561  0.5275  1.3300  1.4949 -0.5153 -0.0677\n",
       " -1.4960  1.6445 -0.5717 -0.2748 -1.4087 -0.3116 -1.2721  1.0730\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(6 ,.,.) = \n",
       " -0.3116 -0.0417  0.0409 -0.8809  0.0966  1.9488  1.3507 -0.0887\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(7 ,.,.) = \n",
       " -0.4732  1.0304 -1.0002 -0.5305  0.8912  0.9431  2.6662  0.8570\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(8 ,.,.) = \n",
       "  0.0232 -0.9956 -0.1627  0.1138 -0.7270 -0.9151  0.8247  0.4661\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(9 ,.,.) = \n",
       " -0.4502  1.5287  2.3401 -0.7399  0.6095  0.0631 -0.9637 -1.9484\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(10,.,.) = \n",
       "  1.2649 -1.5756 -0.1523  1.3332 -0.3693 -0.6410  2.5663  1.0958\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(11,.,.) = \n",
       " -0.3116 -0.0417  0.0409 -0.8809  0.0966  1.9488  1.3507 -0.0887\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(12,.,.) = \n",
       "  0.1462  0.9322  0.5609 -0.7818  1.2379  0.7313  0.2654  0.3523\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "\n",
       "(13,.,.) = \n",
       " -1.4960  1.6445 -0.5717 -0.2748 -1.4087 -0.3116 -1.2721  1.0730\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       " -0.0376  0.8335  0.0578  0.7986 -0.1603  0.2535 -0.2811 -1.1204\n",
       "[torch.FloatTensor of size 14x5x8]"
      ]
     },
     "execution_count": 54,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "input_vars, input_lengths = input_batch\n",
    "embedded = EM(input_vars)\n",
    "\n",
    "embedded     "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# PackedSequence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "\n",
       "(0 ,.,.) = \n",
       "   1   2\n",
       "   3   4\n",
       "   5   6\n",
       "\n",
       "(1 ,.,.) = \n",
       "   7   8\n",
       "   9  10\n",
       "  11  12\n",
       "[torch.FloatTensor of size 2x3x2]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.FloatTensor([[1,2,3],[4,5,6],[7,8,9], [10, 11, 12]]).view(2, 3, 2)\n",
    "# (seq_len, batch_size, data_dim)\n",
    "x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "PackedSequence(data=\n",
       "  1   2\n",
       "  3   4\n",
       "  5   6\n",
       "  7   8\n",
       "  9  10\n",
       " 11  12\n",
       "[torch.FloatTensor of size 6x2]\n",
       ", batch_sizes=[3, 3])"
      ]
     },
     "execution_count": 6,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "pack_padded_sequence(x, [2, 2, 2])"
   ]
  },
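  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `batch_sizes` field shown above counts how many sequences are still active at each time step. The helper below is a hypothetical pure-Python sketch of that bookkeeping, not PyTorch's own code; it assumes a non-empty list of lengths sorted in decreasing order.\n",
    "\n",
    "```python\n",
    "def batch_sizes_from_lengths(lengths):\n",
    "    # Reject input that is not sorted from longest to shortest.\n",
    "    if any(a < b for a, b in zip(lengths, lengths[1:])):\n",
    "        raise ValueError('lengths must be sorted in decreasing order')\n",
    "    # At time step t, a sequence is active if its length exceeds t.\n",
    "    return [sum(1 for n in lengths if n > t) for t in range(lengths[0])]\n",
    "\n",
    "print(batch_sizes_from_lengths([2, 2, 2]))  # [3, 3], matching the PackedSequence above\n",
    "\n",
    "# Lengths taken from the input batch printed in the Data section.\n",
    "sizes = batch_sizes_from_lengths([14, 6, 5, 4, 3])\n",
    "print(sizes)\n",
    "```\n",
    "\n",
    "The batch sizes always sum to the total number of real tokens, which is also the number of rows in the packed data."
   ]
  },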
  {
   "cell_type": "code",
   "execution_count": 78,
   "metadata": {},
   "outputs": [],
   "source": [
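    "from torch.nn.utils.rnn import PackedSequence  # needed by the return statement below\n",
    "\n",
    "# Standalone copy of pack_padded_sequence for study; it shadows the function imported earlier.\n",
    "def pack_padded_sequence(input, lengths, batch_first=False):\n",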
    "    \"\"\"Packs a Variable containing padded sequences of variable length.\n",
    "\n",
    "    Input can be of size ``TxBx*`` where T is the length of the longest sequence\n",
    "    (equal to ``lengths[0]``), B is the batch size, and * is any number of\n",
    "    dimensions (including 0). If ``batch_first`` is True ``BxTx*`` inputs are expected.\n",
    "\n",
    "    Arguments:\n",
    "        input (Variable): padded batch of variable length sequences.\n",
    "        lengths (list[int]): list of sequences lengths of each batch element.\n",
    "        batch_first (bool, optional): if True, the input is expected in BxTx*\n",
    "            format.\n",
    "\n",
    "    Returns:\n",
    "        a :class:`PackedSequence` object\n",
    "    \"\"\"\n",
    "    if lengths[-1] <= 0:\n",
    "        raise ValueError(\"length of all samples has to be greater than 0, \"\n",
    "                         \"but found an element in 'lengths' that is <=0\")\n",
    "    if batch_first:\n",
    "        input = input.transpose(0, 1)\n",
    "\n",
    "    steps = []\n",
    "    batch_sizes = []\n",
    "    lengths_iter = reversed(lengths)\n",
    "    current_length = next(lengths_iter)\n",
    "    batch_size = input.size(1)\n",
    "    if len(lengths) != batch_size:\n",
    "        raise ValueError(\"lengths array has incorrect size\")\n",
    "\n",
    "    for step, step_value in enumerate(input, 1):\n",
    "        steps.append(step_value[:batch_size])\n",
    "        batch_sizes.append(batch_size)\n",
    "\n",
    "        while step == current_length:\n",
    "            try:\n",
    "                new_length = next(lengths_iter)\n",
    "            except StopIteration:\n",
    "                current_length = None\n",
    "                break\n",
    "\n",
    "            if current_length > new_length:  # remember that new_length is the preceding length in the array\n",
    "                raise ValueError(\"lengths array has to be sorted in decreasing order\")\n",
    "            batch_size -= 1\n",
    "            current_length = new_length\n",
    "        if current_length is None:\n",
    "            break\n",
    "    return PackedSequence(torch.cat(steps), batch_sizes)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.4"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
